/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_ACTIVATION_SOFTMAXGRADFRONT_KERNEL_H
#define EXAMPLES_ACTIVATION_SOFTMAXGRADFRONT_KERNEL_H
#include "kernel_operator.h"
namespace MyCustomKernel {
// Queue depth / buffer count used for every TQue and its InitBuffer call in KernelSoftmax.
constexpr int32_t BUFFER_NUM = 1;
// Number of float elements written per row of the z output (identifier spelling "SINGEL" is kept
// because other code refers to it). NOTE(review): 8 floats == 32 bytes, presumably one data block — confirm.
constexpr uint32_t FLOAT_NUM_OF_SINGEL_BLOCK = 8;
// Row count must be a multiple of this for the basic-block SoftmaxGradFront variant to be selected.
constexpr uint32_t BASIC_BLOCK_ROW_FACTOR = 8;
// Column length must be a multiple of this for the basic-block SoftmaxGradFront variant to be selected.
constexpr uint32_t BASIC_BLOCK_COLUMN_FACTOR = 64;
// Column length must be strictly below this for the basic-block SoftmaxGradFront variant to be selected.
constexpr uint32_t BASIC_BLOCK_MAX_COLUMN_LENGTH = 2048;
// Tiling parameters consumed by KernelSoftmax::InitTiling.
// NOTE(review): presumably filled by host-side tiling code; field order/layout should be kept in sync — confirm.
struct VecTiling {
    uint32_t columnLength = 0;                 // elements per row (tile size = rows * columnLength)
    uint32_t rowLength = 0;                    // copied into the kernel but not otherwise read in this file
    uint32_t sharedTmpBufferSize = 0;          // size passed to InitBuffer for the shared temporary buffer
    uint32_t usedBlockDim = 0;                 // block index treated as the tail core; larger indices do no work
    uint32_t coreRowNum = 0;                   // rows per core, used to derive each core's global-memory offset
    uint32_t tailCoreRowNum = 0;               // not read by the kernel in this file
    uint32_t singleLoopCoreRowNum = 0;         // rows processed per loop iteration on a regular core
    uint32_t singleCoreLoopCount = 0;          // number of full loop iterations on a regular core
    uint32_t singleCoreLoopTail = 0;           // leftover rows processed after the full iterations (regular core)
    uint32_t tailCoreSingleLoopCoreRowNum = 0; // rows per loop iteration on the tail core
    uint32_t tailCoreSingleCoreLoopCount = 0;  // number of full loop iterations on the tail core
    uint32_t tailCoreSingleCoreLoopTail = 0;   // leftover rows on the tail core
    SoftMaxTiling softmaxTilingData;           // forwarded unchanged to AscendC::SoftmaxGradFront
};

/**
 * Vector kernel wrapper that drives AscendC::SoftmaxGradFront over the rows assigned to one core.
 *
 * Usage: call Init() once with the three global-memory addresses and the tiling data, then Process().
 * Inputs x and y are read as float rows of `columnLength` elements; the output z receives
 * FLOAT_NUM_OF_SINGEL_BLOCK floats per row.
 */
class KernelSoftmax {
public:
    __aicore__ inline KernelSoftmax() {}

    /**
     * Copies the tiling parameters into the kernel's members.
     * The regular-core loop parameters are stored by default; Init() replaces them with the
     * tail-core values when running on the tail core.
     *
     * @param tilingData tiling parameters for this launch
     */
    __aicore__ inline void InitTiling(const VecTiling& tilingData)
    {
        rowLength = tilingData.rowLength;
        sharedTmpBufferSize = tilingData.sharedTmpBufferSize;
        columnLength = tilingData.columnLength;
        usedBlockDim = tilingData.usedBlockDim;
        coreRowNum = tilingData.coreRowNum;
        softmaxTiling = tilingData.softmaxTilingData;
        singleLoopCoreRowNum = tilingData.singleLoopCoreRowNum;
        singleCoreLoopCount = tilingData.singleCoreLoopCount;
        leftRow = tilingData.singleCoreLoopTail;
        tailCoreSingleLoopCoreRowNum = tilingData.tailCoreSingleLoopCoreRowNum;
        tailCoreSingleCoreLoopCount = tilingData.tailCoreSingleCoreLoopCount;
        tailCoreSingleCoreLoopTail = tilingData.tailCoreSingleCoreLoopTail;
    }

    /**
     * Binds the global tensors for this core and allocates the local queues/buffers.
     *
     * @param x      global address of the first float input
     * @param y      global address of the second float input
     * @param z      global address of the float output (FLOAT_NUM_OF_SINGEL_BLOCK floats per row)
     * @param tiling tiling parameters for this launch
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, const VecTiling& tiling)
    {
        ASSERT(AscendC::GetBlockNum() != 0 && "block dim can not be zero!");
        InitTiling(tiling);

        if (AscendC::GetBlockIdx() == this->usedBlockDim) { // tail core
            this->singleLoopCoreRowNum = this->tailCoreSingleLoopCoreRowNum;
            this->singleCoreLoopCount = this->tailCoreSingleCoreLoopCount;
            this->leftRow = this->tailCoreSingleCoreLoopTail;
        }

        this->blockLength = this->coreRowNum * this->columnLength;

        // Keep the per-core global offsets in 64 bits: block index times per-core length can exceed
        // the uint32_t range for large inputs, and a truncated offset would address the wrong data.
        const uint64_t coreIdx = static_cast<uint64_t>(AscendC::GetBlockIdx());
        const uint64_t outLength = static_cast<uint64_t>(this->coreRowNum) * FLOAT_NUM_OF_SINGEL_BLOCK;
        const uint64_t inOffset = static_cast<uint64_t>(this->blockLength) * coreIdx;
        const uint64_t outOffset = outLength * coreIdx;

        xGm.SetGlobalBuffer((__gm__ float*)x + inOffset, this->blockLength);
        yGm.SetGlobalBuffer((__gm__ float*)y + inOffset, this->blockLength);
        zGm.SetGlobalBuffer((__gm__ float*)z + outOffset, this->coreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK);

        // One tile holds the rows of a single loop iteration.
        this->tileLength = this->singleLoopCoreRowNum * this->columnLength;
        pipe.InitBuffer(queueX, BUFFER_NUM, this->tileLength * sizeof(float));
        pipe.InitBuffer(queueY, BUFFER_NUM, this->tileLength * sizeof(float));
        pipe.InitBuffer(queueZ, BUFFER_NUM, this->singleLoopCoreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK * sizeof(float));

        pipe.InitBuffer(sharedTmpBuffer, sharedTmpBufferSize);
    }

    /**
     * Runs the copy-in / compute / copy-out pipeline over all tiles of this core:
     * `singleCoreLoopCount` full tiles followed by one partial tile of `leftRow` rows, if any.
     * Cores whose block index is greater than `usedBlockDim` return without doing any work.
     */
    __aicore__ inline void Process()
    {
        if (AscendC::GetBlockIdx() > usedBlockDim) {
            return;
        }

        for (uint32_t i = 0; i < this->singleCoreLoopCount; ++i) {
            ProcessTile(i, this->singleLoopCoreRowNum);
        }

        if (this->leftRow > 0) {
            ProcessTile(this->singleCoreLoopCount, this->leftRow);
        }
    }

private:
    // Runs one CopyIn -> Compute -> CopyOut round for tile `progress` containing `rowNum` rows.
    __aicore__ inline void ProcessTile(uint32_t progress, uint32_t rowNum)
    {
        CopyIn(progress, rowNum);
        Compute(rowNum);
        CopyOut(progress, rowNum);
    }

    // Copies `rowNum` rows of x and y for tile `progress` from global memory into the input queues.
    __aicore__ inline void CopyIn(uint32_t progress, uint32_t rowNum)
    {
        AscendC::LocalTensor<float> xLocal = queueX.AllocTensor<float>();
        AscendC::LocalTensor<float> yLocal = queueY.AllocTensor<float>();
        AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], rowNum * this->columnLength);
        AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], rowNum * this->columnLength);
        queueX.EnQue(xLocal);
        queueY.EnQue(yLocal);
    }

    // Calls SoftmaxGradFront on the dequeued x/y tiles of shape (rowNum, columnLength) and enqueues the result.
    __aicore__ inline void Compute(uint32_t rowNum)
    {
        AscendC::LocalTensor<float> xLocal = queueX.DeQue<float>();
        AscendC::LocalTensor<float> yLocal = queueY.DeQue<float>();
        AscendC::LocalTensor<float> zLocal = queueZ.AllocTensor<float>();
        AscendC::LocalTensor<uint8_t> tmpBuffer = sharedTmpBuffer.Get<uint8_t>();

        AscendC::SoftMaxShapeInfo srcShape = { rowNum, this->columnLength, rowNum, this->columnLength };

        // The basic-block decision must be made on the shape actually handed to the API (rowNum),
        // not on the nominal rows-per-loop: the leftover tile may have a row count that is not a
        // multiple of BASIC_BLOCK_ROW_FACTOR even when the full tiles do.
        const bool useBasicBlock = (rowNum % BASIC_BLOCK_ROW_FACTOR == 0) &&
            (this->columnLength % BASIC_BLOCK_COLUMN_FACTOR == 0) &&
            (this->columnLength < BASIC_BLOCK_MAX_COLUMN_LENGTH);
        if (useBasicBlock) {
            AscendC::SoftmaxGradFront<float, true>(zLocal, xLocal, yLocal, tmpBuffer, softmaxTiling, srcShape);
        } else {
            AscendC::SoftmaxGradFront<float>(zLocal, xLocal, yLocal, tmpBuffer, softmaxTiling, srcShape);
        }

        queueZ.EnQue(zLocal);
        queueX.FreeTensor(xLocal);
        queueY.FreeTensor(yLocal);
    }

    // Copies the result rows of tile `progress` from the output queue back to global memory.
    __aicore__ inline void CopyOut(uint32_t progress, uint32_t rowNum)
    {
        AscendC::LocalTensor<float> zLocal = queueZ.DeQue<float>();
        AscendC::DataCopy(zGm[progress * this->singleLoopCoreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK], zLocal,
            rowNum * FLOAT_NUM_OF_SINGEL_BLOCK);
        queueZ.FreeTensor(zLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TBuf<AscendC::TPosition::VECCALC> sharedTmpBuffer;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> queueX;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> queueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> queueZ;
    AscendC::GlobalTensor<float> xGm;
    AscendC::GlobalTensor<float> yGm;
    AscendC::GlobalTensor<float> zGm;

    uint32_t blockLength = 0;          // input elements per core (coreRowNum * columnLength)
    uint32_t usedBlockDim = 0;         // block index of the tail core
    uint32_t rowLength = 0;            // copied from tiling; not otherwise read by this class
    uint32_t columnLength = 0;         // elements per row
    uint32_t coreRowNum = 0;           // rows per core, used for global offsets
    uint32_t tileLength = 0;           // input elements per full tile
    uint32_t sharedTmpBufferSize = 0;  // size of the shared temporary buffer
    uint32_t singleLoopCoreRowNum = 0; // rows per full tile on this core
    uint32_t singleCoreLoopCount = 0;  // number of full tiles on this core
    uint32_t leftRow = 0;              // rows in the trailing partial tile (0 if none)
    uint32_t tailCoreSingleLoopCoreRowNum = 0;
    uint32_t tailCoreSingleCoreLoopCount = 0;
    uint32_t tailCoreSingleCoreLoopTail = 0;
    SoftMaxTiling softmaxTiling;
};
}
#endif // EXAMPLES_ACTIVATION_SOFTMAXGRADFRONT_KERNEL_H